package proxy

import (
	"github.com/fatih/color"
	"golang.org/x/crypto/ssh"
	"gosh/pkg/connect"
	"io"
	"net"
	"os"
	"time"
)

// localProxy opens an SSH connection described by ProxyConfig (through
// jumpServer — presumably an optional jump host, confirm against its
// definition) and forwards ProxyConfig.LocalAddress to
// ProxyConfig.RemoteAddress over it. If the SSH connection cannot be opened,
// the error is printed and the process exits with status 1.
//
// While forward runs, a background goroutine probes the local address up to
// five times, one second apart. It reports readiness on the first successful
// probe; if every probe fails, it signals forward to stop.
func localProxy() {
	host := connect.NewHost(ProxyConfig.Host, ProxyConfig.User, ProxyConfig.Password, ProxyConfig.KeyContent,
		ProxyConfig.KeyPassphrase, ProxyConfig.Port, jumpServer, ProxyConfig.ConnectTimeout)
	if err := host.Open(); err != nil {
		color.Red("[%s] open ssh connection failed,err:%s\n", time.Now().Format("15:04:05"), err.Error())
		os.Exit(1)
	}

	stop := make(chan bool, 1)
	errMsg := make(chan error, 1)
	go func() {
		var probeErr error
		for i := 0; i < 5; i++ {
			time.Sleep(time.Second)
			var probe net.Conn
			probe, probeErr = net.DialTimeout(ProxyConfig.Protocol, ProxyConfig.LocalAddress, ProxyConfig.ConnectTimeout)
			if probeErr == nil {
				// The probe is only a readiness check. Close it immediately;
				// otherwise forward keeps a relayed remote connection open for it.
				_ = probe.Close()
				color.Green("[%s] proxy connection is ready, please access local address %s\n",
					time.Now().Format("15:04:05"), ProxyConfig.LocalAddress)
				return
			}
		}
		// Every probe failed: ask forward to shut down.
		stop <- true
		color.Red("[%s] proxy connection failed\n", time.Now().Format("15:04:05"))
	}()
	forward(host.SSHClient, ProxyConfig.Protocol, ProxyConfig.LocalAddress, ProxyConfig.RemoteAddress, stop, errMsg)
}

// forward listens on localAddr and relays every accepted local connection to
// remoteAddr through client. Each accepted connection is handled by
// establishLocal in its own goroutine.
//
// It blocks until one of the following happens:
//   - opening the local listener fails (the error is printed and forward
//     returns immediately);
//   - a value is received on stop;
//   - the accept loop reports an error on errMsg (the listener is unusable).
//
// Failures of individual relayed connections arrive on an internal channel.
// They are only logged and do not stop forwarding.
//
// The channels used here are intentionally never closed. Other goroutines may
// still send on them, and closing them manually risks a "send on closed
// channel" panic; they are reclaimed by the garbage collector instead.
func forward(client *ssh.Client, protocol, localAddr, remoteAddr string, stop chan bool, errMsg chan error) {
	// Open the local port.
	listener, err := net.Listen(protocol, localAddr)
	if err != nil {
		color.Red("[%s] open local listen address %s failed,err:%s\n", time.Now().Format("15:04:05"), localAddr, err.Error())
		return
	}
	defer func() { _ = listener.Close() }()

	// done is closed when forward returns. The accept goroutine can then stop
	// instead of blocking forever on an errMsg send nobody will receive.
	// (Deferred after the listener close, so it runs first.)
	done := make(chan struct{})
	defer close(done)

	// Per-connection errors from establishLocal. Buffered so short bursts do
	// not stall connection goroutines while the loop below is logging.
	// NOTE(review): once forward returns nothing drains this channel, so more
	// than 5 later dial failures would block their goroutines.
	errChan := make(chan error, 5)

	// Accept loop: hand each local connection to its own relay goroutine.
	go func() {
		for {
			localConn, acceptErr := listener.Accept()
			if acceptErr != nil {
				// An Accept error is terminal for this listener. This includes
				// the error returned after forward closes it. Report it once
				// and exit rather than spinning on the same failure.
				if localConn != nil {
					_ = localConn.Close()
				}
				select {
				case errMsg <- acceptErr:
				case <-done:
				}
				return
			}
			go establishLocal(client, protocol, remoteAddr, localConn, errChan)
		}
	}()

	for {
		select {
		case <-stop:
			color.Green("[%s] receive stop signal,process exit\n", time.Now().Format("15:04:05"))
			return
		case acceptErr := <-errMsg:
			if acceptErr != nil {
				color.Red("[%s] accept on local address %s failed,err:%s\n", time.Now().Format("15:04:05"), localAddr, acceptErr.Error())
			}
			return
		case dialErr := <-errChan:
			if dialErr != nil {
				color.Yellow("[%s] dial remote address %s failed, err:%s\n", time.Now().Format("15:04:05"), remoteAddr, dialErr.Error())
			}
		}
	}
}

// establishLocal relays one accepted local connection to remoteAddr through
// client. Every new local connection gets its own remote connection.
//
// If dialing remoteAddr fails, the error is sent on errChan and local is
// closed. Otherwise data is copied in both directions until both directions
// have finished, and then both connections are closed.
func establishLocal(client *ssh.Client, protocol, remoteAddr string, local net.Conn, errChan chan error) {
	defer func() { _ = local.Close() }()

	remote, err := client.Dial(protocol, remoteAddr)
	if err != nil {
		errChan <- err
		return
	}
	// Close remote exactly once when both copy directions are done.
	defer func() { _ = remote.Close() }()

	// One buffered slot per copy direction, so neither exchangeData goroutine
	// ever blocks on its final send.
	finished := make(chan error, 2)
	go exchangeData(local, remote, finished)
	go exchangeData(remote, local, finished)
	<-finished
	<-finished
}

// closeWriter is implemented by connections that can half-close their write
// side (for example *net.TCPConn). The peer receives EOF while reads on the
// connection keep working.
type closeWriter interface {
	CloseWrite() error
}

// exchangeData copies everything from r into w until r is exhausted or an
// error occurs. If w supports half-closing, its write side is closed
// afterwards. Finally, the io.Copy error (nil on a clean EOF) is sent on
// errCh.
func exchangeData(r io.Reader, w io.Writer, errCh chan error) {
	_, copyErr := io.Copy(w, r)

	// Half-closing is required. Without it the peer never sees EOF, the
	// opposite copy direction may never finish, and the caller waiting on both
	// directions would hold its connections (and memory) indefinitely.
	halfCloser, canHalfClose := w.(closeWriter)
	if canHalfClose {
		_ = halfCloser.CloseWrite() // best effort; the copy result is what gets reported
	}

	errCh <- copyErr
}
